/*
 * Copyright (c) 2017-2022 TIBCO Software Inc. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you
 * may not use this file except in compliance with the License. You
 * may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied. See the License for the specific language governing
 * permissions and limitations under the License. See accompanying
 * LICENSE file.
 */
package org.apache.spark.sql

import java.util.TimeZone

import io.snappydata.benchmark.TPCH_Queries
import io.snappydata.benchmark.snappy.tpch.QueryExecutor
import io.snappydata.benchmark.snappy.{SnappyAdapter, TPCH}
import io.snappydata.{PlanTest, Property, SnappyFunSuite}
import org.scalatest.BeforeAndAfterEach

import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Sort}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.util.Benchmark

class IndexTest extends SnappyFunSuite with PlanTest with BeforeAndAfterEach {

  /**
   * Sets the `spark.sql.codegen.comments` and `spark.testing` system properties
   * to "true" and then delegates to the parent suite's `beforeAll`.
   */
  override def beforeAll(): Unit = {
    // System.setProperty("org.codehaus.janino.source_debugging.enable", "true")
    val suiteProperties = Seq(
      "spark.sql.codegen.comments" -> "true",
      "spark.testing" -> "true")
    suiteProperties.foreach { case (key, value) => System.setProperty(key, value) }
    super.beforeAll()
  }

  /**
   * Clears the system properties set by `beforeAll`, sets the partition-pruning
   * property on `snc.conf` to true (benchmarks in this suite toggle it) and then
   * delegates to the parent suite's `afterAll`.
   */
  override def afterAll(): Unit = {
    // System.clearProperty("org.codehaus.janino.source_debugging.enable")
    Seq("spark.sql.codegen.comments", "spark.testing").foreach { key =>
      System.clearProperty(key)
    }
    Property.PartitionPruning.set(snc.conf, true)
    super.afterAll()
  }

  // Exercises DataFrameWriter.putInto and deleteFrom against a row table partitioned on
  // its primary key, checking the table's row count after every step. The table is
  // dropped at the end so that it does not leak into other tests sharing the context.
  test("test PutInto and DeleteFrom") {

    snc.sql("create table checko (col1 Integer primary key, col2 Integer) using row options " +
        "(partition_by 'col1') ")

    try {
      val data = sc.parallelize(Seq(Row(1, 1), Row(2, 2), Row(3, 3), Row(4, 4), Row(5, 5),
        Row(6, 6)))

      // the DataFrame columns are named i/b; they are aliased to the table's column
      // names (col1/col2) below before being handed to deleteFrom
      val struct = StructType(
        StructField("i", IntegerType, true) ::
            StructField("b", IntegerType, false) :: Nil)

      val df = snc.createDataFrame(data, struct)
      import snappy._
      df.write.putInto("APP.CHECKO")

      assert(snc.sql("select * from checko").count() == 6)

      // rows with i > 4 (keys 5 and 6) are removed, leaving four rows
      df.selectExpr("i as col1", "b as col2").where("i > 4").write.deleteFrom("APP.CHECKO")

      assert(snc.sql("select * from checko").count() == 4)

      // only the primary key column is supplied here; the row with b < 2 (key 1) is removed
      df.filter("b < 2").selectExpr("i as col1").write.deleteFrom("APP.CHECKO")

      assert(snc.sql("select * from checko").count() == 3)
    } finally {
      snc.sql("drop table if exists checko")
    }
  }

  // Creates a row table with an index that includes a LONG VARCHAR column, then updates
  // that column to a much longer, space-padded value and reads the row back both with a
  // full scan and with a lookup on the leading index column. The index and table are
  // dropped afterwards so that they do not leak into other tests.
  test("check varchar index") {
    // Full table definition this test was reduced from, kept for reference:
    /*
        snc.sql("Create table ODS.ORGANIZATIONS(" +
            "org_id bigint GENERATED BY DEFAULT AS IDENTITY  NOT NULL," +
            "ver bigint NOT NULL," +
            "client_id bigint NOT NULL," +
            "org_nm  varchar(80), " +
            "org_typ_ref_id bigint NOT NULL," +
            "descr LONG VARCHAR," +
            "empr_tax_id varchar(25)," +
            "web_site varchar(100)," +
            "eff_dt DATE," +
            "expr_dt DATE," +
            "vld_frm_dt " +
            "TIMESTAMP NOT NULL," +
            "vld_to_dt TIMESTAMP," +
            "src_sys_ref_id LONG VARCHAR NOT NULL," +
            "src_sys_rec_id LONG VARCHAR," +
            "PRIMARY KEY (client_id,org_id)" +
            ")" +
            "using row options (partition_by 'org_id')" +
            "")
    */
    snc.sql("Create table ODS.ORGANIZATIONS(" +
        "org_id bigint GENERATED BY DEFAULT AS IDENTITY  NOT NULL," +
        "client_id bigint NOT NULL," +
        "descr LONG VARCHAR," +
        "PRIMARY KEY (client_id,org_id)" +
        ") " +
        "using row options (partition_by 'org_id')" +
        "")

    try {
      snc.sql("create index ods.idx_org on ODS.ORGANIZATIONS (CLIENT_ID, DESCR)")

      snc.sql("insert into ods.organizations(client_id, descr) values(8006, 'EL')")
      snc.sql(
        "update ods.organizations set descr = 'EL                                            " +
            "                                                                      " +
            "  ' where client_id = 8006")
      snc.sql("select * from ods.organizations").collect()
      val matched = snc.sql(
        "select client_id, descr from ods.organizations where client_id = 8006").collect()
      // exactly one row was inserted, so the lookup must still return exactly one row
      assert(matched.length == 1)
    } finally {
      snc.sql("drop index if exists ods.idx_org")
      snc.sql("drop table if exists ods.organizations")
    }
  }

  // Runs TPC-H queries 1 to 22 twice, once through the TPCH/SnappyAdapter provider and
  // once through QueryExecutor, with the experimental-features property set to "true",
  // and fails the test if the two result sets differ for any query.
  test("tpch queries") {
    val qryProvider = new TPCH with SnappyAdapter

    val queries = Array("1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11",
      "12", "13", "14", "15", "16", "17", "18", "19",
      "20", "21", "22")

    TPCHUtils.createAndLoadTables(snc, true)

    val experimentalFeatures = io.snappydata.Property.EnableExperimentalFeatures.name
    val existing = snc.getConf(experimentalFeatures)
    snc.setConf(experimentalFeatures, "true")

    try {
      for ((q, i) <- queries.zipWithIndex) {
        // the provider takes the query number as an Int; the array above is in order
        val qNum = i + 1
        val (expectedAnswer, _) = qryProvider.execute(qNum, str => {
          snc.sql(str)
        })
        val queryToBeExecuted = TPCH_Queries.getQuery(q, false, true)
        val (newAnswer, df) = QueryExecutor.queryExecution(q, queryToBeExecuted, snc, false)
        // row order is only compared when the plan contains an explicit Sort
        val isSorted = df.logicalPlan.collect { case s: Sort => s }.nonEmpty
        // sameRows returns a description of the difference, if any; it has to be turned
        // into a test failure explicitly, otherwise a mismatch goes unnoticed
        QueryTest.sameRows(expectedAnswer, newAnswer, isSorted).foreach { results =>
          fail(
            s"""
               |Results do not match for query: $qNum
               |Timezone: ${TimeZone.getDefault}
               |Timezone Env: ${sys.env.getOrElse("TZ", "")}
               |
               |${df.queryExecution}
               |== Results ==
               |$results
             """.stripMargin)
        }
        logInfo(s"Done $qNum")
      }
    } finally {
      // restore the previous setting even when a query throws or a comparison fails
      snc.setConf(experimentalFeatures, existing)
    }

  }

  // Disabled benchmark: loads the TPC-H tables, creates two colocated indexes, logs the
  // row count of every table and runs point lookups on orders.o_orderkey through
  // runBenchmark. Both indexes are dropped again at the end, whether or not the run fails.
  ignore("Benchmark tpch") {

    try {
      // only referenced by the commented-out full TPC-H run at the end of this block
      val queries = Array("1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11",
        "12", "13", "14", "15", "16", "17", "18", "19",
        "20", "21", "22")

      sc(c => c.set("spark.local.dir", "/data/temp"))

      TPCHUtils.createAndLoadTables(snc, true)

      val indexDDLs = Seq(
        s"""CREATE INDEX idx_orders_cust ON orders(o_custkey)
             options (COLOCATE_WITH 'customer')
          """,
        s"""CREATE INDEX idx_lineitem_part ON lineitem(l_partkey)
             options (COLOCATE_WITH 'part')
          """)
      indexDDLs.foreach(ddl => snc.sql(ddl))

      val tables = Seq("nation", "region", "supplier", "customer", "orders", "lineitem", "part",
        "partsupp")

      val tableSizes = tables.map(name => name -> snc.table(name).count()).toMap

      logInfo(tableSizes.mkString("\n"))

      // (order key to look up, numSecs handed to runBenchmark)
      val lookups = Seq(1 -> 2, 32 -> 0, 801 -> 0, 1409 -> 0)
      lookups.foreach { case (orderKey, numSecs) =>
        runBenchmark(s"select o_orderkey from orders where o_orderkey = $orderKey",
          tableSizes, numSecs)
      }
      // queries.foreach(q => benchmark(q, tableSizes))

    } finally {
      Seq("idx_orders_cust", "idx_lineitem_part").foreach { indexName =>
        snc.sql(s"DROP INDEX if exists $indexName")
      }
    }
  }

  /** Sets the partition-pruning property on the conf of the given context to `onOff`. */
  private def togglePruning(onOff: Boolean, snc: SnappyContext) = {
    val contextConf = snc.conf
    Property.PartitionPruning.set(contextConf, onOff)
  }

  /**
   * Benchmarks `queryString` with partition pruning switched off and then on.
   *
   * @param queryString SQL text to execute and collect in every benchmark case
   * @param tableSizes  row count per lower-case table name, used to estimate the
   *                    number of input rows of the query
   * @param numSecs     warm-up time, in seconds, passed to the `Benchmark`
   */
  def runBenchmark(queryString: String, tableSizes: Map[String, Long], numSecs: Int = 0): Unit = {

    // This is an indirect hack to estimate the size of each query's input by traversing the
    // logical plan and adding up the sizes of all tables that appear in the plan. Note that this
    // currently doesn't take WITH subqueries into account which might lead to fairly inaccurate
    // per-row processing time for those cases.
    val queryRelations = scala.collection.mutable.HashSet[String]()
    // foreach rather than map: the traversal is only needed for its side effect on the set
    snc.sql(queryString).queryExecution.logical.foreach {
      case UnresolvedRelation(t: TableIdentifier, _) =>
        queryRelations.add(t.table.toLowerCase)
      case lp: LogicalPlan =>
        // relations referenced from subquery expressions of this node (one level deep)
        lp.expressions.foreach {
          _ foreach {
            case subquery: SubqueryExpression =>
              subquery.plan.foreach {
                case UnresolvedRelation(t: TableIdentifier, _) =>
                  queryRelations.add(t.table.toLowerCase)
                case _ =>
              }
            case _ =>
          }
        }
    }
    // Sum over an iterator: mapping the set directly would produce a set of sizes, so
    // tables that happen to have the same row count would be counted only once.
    val size = queryRelations.iterator.map(tableSizes.getOrElse(_, 0L)).sum

    import scala.concurrent.duration._
    val b = new Benchmark(s"JoinOrder optimization", size,
      warmupTime = numSecs.seconds)
    // both cases run the same query; only the pruning flag set in `prepare` differs
    b.addCase("WithOut Partition Pruning", numIters = 0,
      prepare = () => togglePruning(onOff = false, snc),
      cleanup = () => {})(_ => snc.sql(queryString).collect())
    b.addCase("With Partition Pruning", numIters = 0,
      prepare = () => togglePruning(onOff = true, snc),
      cleanup = () => {})(_ => snc.sql(queryString).collect())
    b.run()
  }

  /**
   * Benchmarks one TPC-H query, executed through `QueryExecutor`, once with
   * partition pruning switched off and once with it switched on.
   *
   * @param qNum       TPC-H query number as a string; it must parse as an Int
   * @param tableSizes row count per table name, handed to the provider's size estimate
   */
  def benchmark(qNum: String, tableSizes: Map[String, Long]): Unit = {

    val qryProvider = new TPCH with SnappyAdapter
    val query = qNum.toInt

    def executor(str: String) = snc.sql(str)

    val size = qryProvider.estimateSizes(query, tableSizes, executor)
    logInfo(s"$qNum size $size")
    val b = new Benchmark(s"JoinOrder optimization", size, minNumIters = 10)

    // The helpers below are only referenced from the commented-out cases further down.
    def setExperimentalFeatures(flag: String): Unit =
      snc.setConf(io.snappydata.Property.EnableExperimentalFeatures.name, flag)

    def case1(): Unit = setExperimentalFeatures("false")

    def case2(): Unit = setExperimentalFeatures("true")

    def case3(): Unit = {
      setExperimentalFeatures("true")
    }

//    def evalSnappyMods(genPlan: Boolean) = TPCH_Snappy.queryExecution(qNum, snc, useIndex = false,
//      genPlan = genPlan)._1.foreach(_ => ())

    val queryToBeExecuted = TPCH_Queries.getQuery(qNum, false, true)

    // `genPlan` is accepted but not used by this implementation
    def evalSnappyMods(genPlan: Boolean) = {
      val (rows, _) = QueryExecutor.queryExecution(qNum, queryToBeExecuted, snc, false)
      rows.foreach(_ => ())
    }

    def evalBaseTPCH = qryProvider.execute(query, executor)._1.foreach(_ => ())

    //    b.addCase(s"$qNum baseTPCH index = F", prepare = case1)(i => evalBaseTPCH)
    //    b.addCase(s"$qNum baseTPCH joinOrder = T", prepare = case2)(i => evalBaseTPCH)
    // one case per pruning setting, registered in the order off then on
    Seq(false -> "without", true -> "with").foreach { case (pruningOn, label) =>
      b.addCase(s"$qNum $label PartitionPruning", numIters = 0,
        prepare = () => togglePruning(onOff = pruningOn, snc),
        cleanup = () => {})(_ => evalSnappyMods(false))
    }
    /*
        b.addCase(s"$qNum snappyMods joinOrder = T", prepare = case2)(i => evalSnappyMods(false))
        b.addCase(s"$qNum baseTPCH index = T", prepare = case3)(i => evalBaseTPCH)
    */
    b.run()

  }

  // Placeholder: the body consists solely of commented-out code, so this test executes
  // nothing and always passes.
  // NOTE(review): consider marking it as ignored or pending until the NorthWind checks
  // below are restored.
  test("northwind queries") {
    // Disabled setup and join assertion for NorthWind query Q42:
    //    val sctx = sc(c => c.set("spark.sql.inMemoryColumnarStorage.batchSize", "40000"))
    //    val snc = getOrCreate(sctx)
    //    NorthWindDUnitTest.createAndLoadColumnTables(snc)
    //    val s = "select distinct shipcountry from orders"
    //    snc.sql(s).collect()
    //    NWQueries.assertJoin(snc, NWQueries.Q42, "Q42", 22, 1, classOf[LocalJoin])
    /*
        Thread.sleep(1000 * 60 * 60)
        NWQueries.assertJoin(snc, NWQueries.Q42, "Q42", 22, 1, classOf[LocalJoin])
    */
  }

}
